# Predict NIPT and invasive-testing decisions for every row of `data`.
#
# Arguments:
#   par    parameter vector; only par[6] (psi) is used here.
#   data   data frame that already holds the preference draws (a_i, c_i), the
#          true risk p_i, the out-of-pocket prices oop_i (NIPT) and oop2_i
#          (invasive), and sim_nipt_result.
#   prior  quosure (e.g. quo(p_i)) giving the mother's prior risk; it is
#          unquoted into the column `mother_prior`.
#   setup  1 = return only a_i, c_i, pred_nipt, pred_invasive (used when
#          drawing a_i/c_i conditional on the data); anything else = return the
#          wider set of columns used for the counterfactuals.
#
# `pfp`, `pfn` and `pa` are not arguments: they are resolved by dplyr data
# masking (a column of `data` if one exists, otherwise the calling
# environment). NOTE(review): presumably the NIPT false-positive rate,
# false-negative rate and the invasive miscarriage probability -- confirm.
#
# Returns the data frame restricted to the columns selected for `setup`.
model_decisions_counterfactuals <- function(par, data, prior, setup = 0){
  
  psi <- par[6]
  
  # Beliefs: the mother's prior may differ from the truth (p_i).
  # The recommendation flag is still derived from the true risk p_i; when it is
  # set, the prior is transformed to prior^psi everywhere on the decision tree.
  # (The original code first computed p from p_i and then overwrote it with the
  # prior-based value; only the prior-based value was ever used.)
  final <- data %>% 
    mutate(mother_prior = (!!prior)) %>%
    mutate(rec = ifelse(p_i >= 1/200, 1, 0)) %>%
    mutate(p = ifelse(rec==1, mother_prior^psi, mother_prior))
  
  # NIPT result probabilities and posteriors (Bayes rule on the believed p).
  final <- final %>% 
    mutate(true_pr_pos = (1-p_i)*pfp + p_i*(1-pfn) ) %>% # note: uses truth=p_i, not p (which is based on prior)
    mutate(pr_pos = (1-p)*pfp + p*(1-pfn) ) %>%
    mutate(pr_c_pos = (1-pfn)*p/pr_pos ) %>%
    mutate(pr_c_neg = pfn*p/(1-pr_pos) ) 
  
  # Second-stage utilities: Y/N = invasive test yes/no;
  # 0/P/N = no NIPT / NIPT positive / NIPT negative.
  final <- final %>% 
    mutate(ui2_Y0 = pa*a_i + (1-pa)*p*pmax(a_i, c_i) - oop2_i ) %>% 
    mutate(ui2_N0 = pmax(a_i, p*c_i) ) %>% 
    mutate(ui2_YP = pa*a_i + (1-pa)*pr_c_pos*pmax(a_i, c_i) - oop_i - oop2_i ) %>% 
    mutate(ui2_NP = pmax(a_i, pr_c_pos*c_i) - oop_i ) %>% 
    mutate(ui2_YN = pa*a_i + (1-pa)*pr_c_neg*pmax(a_i, c_i) - oop_i - oop2_i ) %>% 
    mutate(ui2_NN = pmax(a_i, pr_c_neg*c_i) - oop_i ) 
  
  # Invasive decision at each node; utilities are rounded to 5 digits so that
  # numerical ties resolve in favour of testing.
  final <- final %>%
    mutate(invasive_P = ifelse(round(ui2_YP, 5) >= round(ui2_NP, 5), 1, 0)) %>%
    mutate(invasive_N = ifelse(round(ui2_YN, 5) >= round(ui2_NN, 5), 1, 0)) %>%
    mutate(invasive_0 = ifelse(round(ui2_Y0, 5) >= round(ui2_N0, 5), 1, 0)) 
  
  # First-stage (NIPT) utilities: expected value of the best second-stage choice.
  final <- final %>%
    mutate(ui1_Y = pr_pos*pmax(ui2_YP, ui2_NP) + (1-pr_pos)*pmax(ui2_YN, ui2_NN) ) %>%
    mutate(ui1_N = pmax(ui2_Y0, ui2_N0) )
  
  # NIPT decision: indifference (equal after rounding) resolves to "no NIPT".
  final <- final %>%
    mutate(nipt_indiff = ifelse(round(ui1_Y, 5) == round(ui1_N, 5), 1, 0)) %>%
    mutate(pred_nipt = ifelse(nipt_indiff==0, 
                              ifelse(ui1_Y > ui1_N, 1, 0), 0))
  
  # Realised invasive decision given the NIPT choice and the simulated result.
  final <- final %>%
    mutate(pred_invasive = ifelse(pred_nipt==0, invasive_0, ifelse(sim_nipt_result == 1, invasive_P, invasive_N))) 
  
  if (setup == 1) {
    # for drawing a_i and c_i and conditioning on the data
    final <- final %>% 
      dplyr::select(a_i, c_i, pred_nipt, pred_invasive)      
  } else {
    # for the actual counterfactuals
    final <- final %>% 
      dplyr::select(pregnancy, did_nipt, did_invasive, live_birth, p_i, fetus_risk, wave, oop_i, age, sim_positive, 
                    a_i, c_i, sim_wgt, sim_inv_mis, oop2_i, rec_i, ave_kub_age,
                    pred_nipt, pred_invasive, invasive_P, invasive_N, sim_nipt_result)
  }
  
  return(final)
}

# Simulate pregnancy outcomes (live birth / abortion / miscarriage) and realised
# utility for each row, given the predicted testing decisions already present
# in `data` (pred_nipt, pred_invasive, invasive_P, invasive_N).
#
#   par    parameter vector; only par[6] (psi) is used.
#   data   data frame with the decision columns above plus a_i, c_i, p_i,
#          rec_i, oop_i, oop2_i, sim_positive, sim_nipt_result, sim_inv_mis.
#   prior  quosure for the mother's prior risk.
#
# `pfp` and `pfn` are resolved through dplyr data masking (column of `data` if
# present, otherwise the calling environment).
#
# Returns the simulated outcome columns together with identifiers and the
# node-level intermediates listed in `node_cols`.
outcomes <- function(par, data, prior){
  
  psi <- par[6]
  
  # ---- Beliefs and NIPT posteriors (rounded to 10 digits) ----
  data <- data %>% 
    mutate(
      mother_prior = (!!prior),
      p            = ifelse(rec_i==1, mother_prior^psi, mother_prior),
      # true_pr_pos uses the truth (p_i); everything else uses the believed p
      true_pr_pos  = round( (1-p_i)*pfp + p_i*(1-pfn) , 10),
      pr_pos       = round( (1-p)*pfp + p*(1-pfn) , 10),
      pr_c_pos     = round( (1-pfn)*p/pr_pos , 10),
      pr_c_neg     = round( pfn*p/(1-pr_pos) , 10)
    )
  
  # ---- Terminal nodes ----
  # Nodes 1-3 (no invasive test, so no miscarriage): abort when a_i exceeds the
  # believed probability of a chromosomal abnormality times c_i.
  #   1: no NIPT          -> belief p
  #   2: NIPT positive    -> belief pr_c_pos
  #   3: NIPT negative    -> belief pr_c_neg
  # Node 4 (invasive test, any NIPT status): the test can cause a miscarriage;
  # otherwise the truth is revealed and abortion happens iff positive & a_i>c_i.
  data <- data %>% 
    mutate(
      sim_live1  = ifelse(a_i > p*c_i, 0, 1),
      sim_abort1 = ifelse(a_i > p*c_i, 1, 0),
      sim_mis1   = 0,
      sim_live2  = ifelse(a_i > pr_c_pos*c_i, 0, 1),
      sim_abort2 = ifelse(a_i > pr_c_pos*c_i, 1, 0),
      sim_mis2   = 0,
      sim_live3  = ifelse(a_i > pr_c_neg*c_i, 0, 1),
      sim_abort3 = ifelse(a_i > pr_c_neg*c_i, 1, 0),
      sim_mis3   = 0,
      sim_mis4   = sim_inv_mis,
      sim_abort4 = ifelse(sim_inv_mis==0 & sim_positive==1 & a_i>c_i, 1, 0),
      sim_live4  = ifelse(sim_mis4==0 & sim_abort4==0, 1, 0)
    )
  
  # ---- Where does an NIPT taker end up after a positive / negative result? ----
  data <- data %>%
    mutate(
      sim_liveP  = ifelse(invasive_P==0, sim_live2, sim_live4),
      sim_abortP = ifelse(invasive_P==0, sim_abort2, sim_abort4),
      sim_misP   = ifelse(invasive_P==0, sim_mis2, sim_mis4),
      sim_liveN  = ifelse(invasive_N==0, sim_live3, sim_live4),
      sim_abortN = ifelse(invasive_N==0, sim_abort3, sim_abort4),
      sim_misN   = ifelse(invasive_N==0, sim_mis3, sim_mis4)
    )
  
  # ---- Realised outcomes along the chosen path ----
  data <- data %>%
    mutate(
      sim_live  = ifelse(pred_nipt==0, 
                         ifelse(pred_invasive==0, sim_live1, sim_live4),
                         ifelse(sim_nipt_result == 1, sim_liveP, sim_liveN)),
      sim_abort = ifelse(pred_nipt==0, 
                         ifelse(pred_invasive==0, sim_abort1, sim_abort4),
                         ifelse(sim_nipt_result == 1, sim_abortP, sim_abortN)),
      sim_mis   = ifelse(pred_nipt==0, 
                         ifelse(pred_invasive==0, sim_mis1, sim_mis4),
                         ifelse(sim_nipt_result == 1, sim_misP, sim_misN)),
      sim_terminate = sim_abort + sim_mis,
      # gross utility: a_i if the pregnancy ends, c_i for a live birth with an
      # abnormality, 0 for a live birth without one
      sim_util = ifelse(sim_live == 0, a_i,
                        ifelse(sim_positive == 1, c_i, 0)),
      # net of the price of each test actually taken
      sim_util = sim_util -
        ifelse(pred_nipt == 0, 0, oop_i) -
        ifelse(pred_invasive == 0, 0, oop2_i)
    )
  
  # Node-level columns carried into the output (order fixes the column order).
  node_cols <- c(paste0("sim_live", c(1, 2, 4)),
                 paste0("sim_mis", 1:4),
                 paste0("sim_abort", 1:4),
                 "sim_liveP", "sim_misP", "sim_abortP",
                 "sim_liveN", "sim_misN", "sim_abortN",
                 "invasive_P", "invasive_N", "pred_invasive")
  
  data <- data %>%
    mutate(sim_nipt = pred_nipt,          # just renaming
           sim_invasive = pred_invasive)  # just renaming
  
  data <- dplyr::select(data,
                        pregnancy, sim_wgt, sim_positive, sim_nipt, sim_nipt_result,
                        sim_invasive, sim_live, sim_abort, sim_mis, sim_terminate,
                        sim_util, oop_i, oop2_i, a_i, c_i, p_i, fetus_risk,
                        all_of(node_cols))
  
  return(data)
  
}


# conditional_draws <- function(par, data, x_vars, x_vars_reg, K=100, condition_data=1, f_a, f_c) {
#   print("Estimated parameters with Xs")
#   print(par)
#   actual <- data
#   actual_1 <- actual
  
#   #Step 0a: expand data K times
#   if(K > 1){
#     for (i in 1:(K-1)) {
#       actual <- rbind(actual, actual_1)
#     }
#   }
#   rm(i, actual_1)
#   actual <- actual %>% arrange(pregnancy)
  
#   #Step 0b: simulate positive indicator for chromAb 
#   set.seed(1292837)
#   actual <- actual %>% 
#     mutate(draw = runif(nrow(.), min=0, max=1)) %>%
#     mutate(sim_positive = ifelse(draw <= p_i, 1, 0)) %>% 
#     select(-draw)
  
#   #Step 0c: simulate result of NIPT test, if taken (using simulated chromAb indicator)
#   # added new mutate(sim_nipt_result) before select(-draw) to address dropping pregnancies in conditional counterfactual
#   actual <- actual %>% 
#     mutate(draw = runif(nrow(.), min=0, max=1)) %>%
#     mutate(sim_nipt_result = ifelse(sim_positive == 1, 
#                                     ifelse(draw > pfn, 1, 0),
#                                     ifelse(draw <= pfp, 1, 0))) %>% 
#     mutate(sim_nipt_result = ifelse(did_nipt == 1,
#                                     ifelse(did_invasive == 1, 1, 0),
#                                     sim_nipt_result)) %>%
#     select(-draw)
  
#   # Step 1: simulate K (a_i, c_i) draws per pregnancy (using parameter estimates) + calculate predicted testing decisions
#   # note: condition on the data = discard (a_i, c_i) draws that are rejected by the observed testing decisions
#   set.seed(101)
#   #f_a = ~ 0 + prev_mis_still
#   #f_c = ~ 0 + prev_mis_still
#   # create an exception flag for those in wave 2 that got NIPT
#   # why?: we can't condition on the data for these guys' NIPT decision because the model will never predict that outcome
#   actual <- actual %>%
#     mutate(exception = ifelse(wave==2 & did_nipt==1, 1, 0))
  
#   # simulate - first iteration
#   iter = 1
#   #draw <- draw_a_c_counter(par=est_par, data=actual)
#   draw <- draw_a_c_counter_xs(par=par, data=actual, f_a=f_a, f_c=f_c, x_vars=x_vars)
#   draw <- as.data.frame(draw) %>%
#     mutate(oop2_i = 0)
#   simulated <- model_decisions_counterfactuals(par=par, data=draw, prior=quo(p_i), setup = 1) %>%
#     as.data.frame() %>%
#     dplyr::select(a_i, c_i, pred_nipt, pred_invasive)
  
#   all <- cbind(actual, simulated) %>%
#     mutate(keeper = ifelse(did_nipt==pred_nipt,
#                            1,
#                            ifelse(exception==1, 1, 0))) %>%
#     mutate(keeper = ifelse(did_nipt==1,
#                            keeper,
#                            ifelse(did_invasive==pred_invasive, 1, 0))) %>%
#     group_by(pregnancy) %>%
#     mutate(share_keeper = mean(keeper)) %>%
#     mutate(num_keeper = sum(keeper)) %>%
#     ungroup()
  
#   rm(actual, simulated, draw)
  
#   # keep first iteration in a separate database
#   first <- all
  
#   # simulate - next iterations
  
#   while(iter < 250){
#     iter = iter + 1
    
#     good <- all %>% filter(share_keeper >= 0.1 | keeper==1) %>%
#       dplyr::select(-keeper, -share_keeper)
    
#     bad_act <- all %>% filter(share_keeper < 0.1 & keeper==0) %>%
#       dplyr::select(-a_i, -c_i, -pred_nipt, -pred_invasive, -keeper, -share_keeper)
#     #bad_draw <- draw_a_c_counter(par=est_par, data=bad_act)
#     bad_draw <- draw_a_c_counter_xs(par=par, data=bad_act, f_a=f_a, f_c=f_c, x_vars = x_vars)
#     bad_draw <- as.data.frame(bad_draw) %>%
#       mutate(oop2_i = 0)
    
#     bad_sim <- model_decisions_counterfactuals(par=par, data=bad_draw, prior=quo(p_i), setup = 1) %>%
#       as.data.frame() %>%
#       dplyr::select(a_i, c_i, pred_nipt, pred_invasive)
    
#     bad <- cbind(bad_act, bad_sim)
    
#     all <- rbind(good, bad) %>%
#       mutate(keeper = ifelse(did_nipt==pred_nipt,
#                              1,
#                              ifelse(exception==1, 1, 0))) %>%
#       mutate(keeper = ifelse(did_nipt==1,
#                              keeper,
#                              ifelse(did_invasive==pred_invasive, 1, 0))) %>%
#       arrange(pregnancy) %>%
#       group_by(pregnancy) %>%
#       mutate(share_keeper = mean(keeper)) %>%
#       mutate(num_keeper = sum(keeper)) %>%
#       ungroup()
#     rm(good, bad_act, bad_draw,  bad_sim, bad)
    
#   }
#   rm(iter)
  
#   # keep only keepers
#   summary(all$share_keeper)
#   table(all$keeper)
  
 
#   bad <- all %>% 
#     filter(num_keeper == 0) %>%
#     dplyr::select(pregnancy, did_nipt, did_invasive, fetus_risk, pred_invasive, pred_nipt,  p_i, lan, wave, oop_i, age, sim_positive, a_i, c_i, sim_nipt_result)
  
#   ### Write dataset with the bad draws
#   unique_bad <- bad %>%
#     dplyr::select(pregnancy, did_nipt, did_invasive, p_i, fetus_risk, oop_i) 
#   unique_bad <- unique(unique_bad)
#   ### Generate stats on pregnancies that we lose from conditioning before changing the conditioning procedure
#   bad_id <- bad %>%
#     select(pregnancy,p_i)
#   unique_bad <- unique(bad_id)
#   all_unique <- all %>%
#     select(pregnancy, p_i, num_keeper) 
#   # all_unique <- unique(all_unique)
#   # preg_count <-nrow(all_unique)
#   # bad_preg_count <- nrow(filter(all_unique, num_keeper == 0))
#   # good_preg_count <- nrow(filter(all_unique, num_keeper > 0))
#   # share_bad_unique <- bad_preg_count / preg_count
#   # share_good_unique <- 1 - share_bad_unique
#   # num_bad_low_risk <- nrow(filter(all_unique, num_keeper == 0 & p_i < 1/200))
#   # num_all_low_risk <-nrow(filter(all_unique, p_i < 1/200))
#   # num_good_low_risk <- nrow(filter(all_unique, num_keeper > 0 & p_i < 1/200))
#   # share_bad_low_risk <- num_bad_low_risk / bad_preg_count
#   # share_all_low_risk <- num_all_low_risk / preg_count
#   # share_good_low_risk <- num_good_low_risk / good_preg_count
#   # unique_preg_stats <- matrix(NA, nrow = 3, ncol = 4)
#   # unique_preg_stats[1,] <- c(bad_preg_count, share_bad_unique, num_bad_low_risk, share_bad_low_risk)
#   # unique_preg_stats[2,] <- c(good_preg_count, share_good_unique, num_good_low_risk, share_good_low_risk)
#   # unique_preg_stats[3,] <- c(preg_count, 1, num_all_low_risk, share_all_low_risk)
#   # #write.csv(unique_preg_stats, file=paste0(RESULTS, "/baby_model_nipt/pregnancy_stats.csv"), na=".", row.names=TRUE)
#   #####
  
#   bad <- bad %>% mutate(pref_chrom = ifelse(c_i >= a_i, 1, 0))
#   table(bad$pref_chrom)
#   table(bad$did_nipt)
#   table(bad$did_invasive)
#   print("finished with all the drawing iterations")
#   #write.dta(bad, paste0(TEMP, "/baby_model_nipt/bad_draws.dta"))
  
#   if (condition_data == 1) {
#     print("Condition on data")
#     all <- all %>% 
#       filter(keeper==1)
#     # create simulation weight = 1 / # keepers (same for every draw within a pregnancy)
#     all <- all %>% mutate(sim_wgt = 1 / (share_keeper * K))
#     table(all$keeper)
#   } else {
#     print("No condition on data") #  to check what happens if we don't condition on data
#     all <- all %>% mutate(sim_wgt = 1) 
#   }
  
#   ### Create conditional a_i, c_i draw sample for the regression for starting parameters
#   print("create conditional a_i, c_i draw sample for the regression for starting parameters")
#   bad_reg <- bad %>%
#     dplyr::select(pregnancy, a_i, c_i, did_nipt, did_invasive) %>%
#     group_by(pregnancy) %>%
#     filter(row_number() == 1) %>%
#     ungroup()
#   all_reg <- all %>%
#     dplyr::select(pregnancy, a_i, c_i, did_nipt, did_invasive) %>%
#     group_by(pregnancy) %>%
#     filter(row_number() == 1) %>%
#     ungroup()
#   reg <- rbind(all_reg, bad_reg)
#   print("created full reg sample")
#   #write.dta(reg, paste0(TEMP, "/baby_model_nipt/regression_sample_xs.dta"))
#   #print("saved regression sample")
#   reg_sample <- merge(reg, data, by = "pregnancy") %>%
#     dplyr::select(pregnancy, a_i, c_i, all_of(x_vars_reg))
#   formula_a <- as.formula(paste("a_i", paste(x_vars_reg, collapse = " + "), sep = " ~ "))
#   print("formula_a")
#   print(formula_a)
#   formula_c <- as.formula(paste("c_i", paste(x_vars_reg, collapse = " + "), sep = " ~ "))
  
#   a_i_mod <-lm(formula_a, data = reg_sample)
#   c_i_mod <-lm(formula_c, data = reg_sample)
#   sd_a <- sd(resid(a_i_mod))
#   sd_c <- sd(resid(c_i_mod))
#   mu_a <- tidy(a_i_mod)$estimate[1]
#   mu_c <- tidy(c_i_mod)$estimate[1]
#   num_xs = length(x_vars)
#   beta_as <- tidy(a_i_mod)$estimate[-1]
#   beta_cs <- tidy(c_i_mod)$estimate[-1]
#   print(x_vars_reg)
#   print("beta_as")
#   print(beta_as)
#   print("beta_cs")
#   print(beta_cs)
#   rho <- cor(resid(a_i_mod), resid(c_i_mod))
#   print("rho:")
#   print(rho)
#   reg_pars <- c(mu_a, mu_c, sd_a, sd_c, rho, beta_as, beta_cs)
#   #write.csv(reg_pars, file=paste0(TEMP, "/Prenatal/new_parameter_val/starting_parameter_val_new.csv"), na=".", row.names=FALSE)
#   print("reg pars here:")
#   print(reg_pars)
  
#   return(reg_pars)
# }

# Draw (a_i, c_i) preference pairs from a bivariate normal truncated to a_i <= 0,
# with observation-specific means shifted by covariates.
#
#   par          parameter vector: par[1:2] base means of a and c, par[3:4]
#                standard deviations, par[5] correlation, par[7...] the
#                mean-shift coefficients (first the f_a columns, then f_c).
#   data         data frame holding the columns selected below plus `x_vars`.
#   f_a, f_c     one-sided formulas for the mean shifters of a and c.
#   x_vars       character vector of covariate column names used by f_a / f_c.
#   J            number of stacked copies of the data.
#   unique_flag  1 = run the truncation pre-check once per distinct mean pair;
#                otherwise run it once per row.
#
# Returns the simulation data with columns a_i and c_i appended, or NA (after
# printing the condition) if anything inside the tryCatch fails.
#
# Relies on make.positive.definite(), mvrnorm(), rtmvnorm() and
# apply_meanshift() being available in the calling environment.
#
# NOTE(review): the mean shifts are computed from `data`, not from the J-times
# stacked `sim`, so for J > 1 `mu` has fewer rows than `sim` -- confirm whether
# J > 1 is ever used.
draw_a_c_counter_xs <- function(par, data, f_a, f_c, x_vars, J=1, unique_flag=1){
  
  a_mean <- par[1]
  c_mean <- par[2]
  a_sd <- par[3]
  c_sd <- par[4]
  rho <- par[5]
  
  # pull in data
  sim <- data %>% 
    dplyr::select(pregnancy, p_i, wave, oop_i, bin_number, policy_id, policy_regime, sim_positive, sim_nipt_result, all_of(x_vars))
  sim <- as.matrix(sim)
  colnames(sim) <- c("pregnancy", "p_i", "wave", "oop_i", "bin_number", "policy_id", "policy_regime", "sim_positive", "sim_nipt_result", x_vars)
  
  # replicate J times to make J draws
  if (J > 1) {
    sim <- do.call(rbind, rep(list(sim), J))
  }
  
  # draw a's and c's
  # set parameters for bivariate normal distribution
  sigma <- matrix(c(a_sd^2, a_sd*c_sd*rho, a_sd*c_sd*rho, c_sd^2), 2, 2) # Covariance matrix
  sigma <- make.positive.definite(sigma, tol=1e-8) # Fix for sigma sometimes not being positive definite due to optimization
  
  tryCatch({
    # truncation bounds: a_i <= 0, c_i unrestricted
    low <- c(-Inf, -Inf)
    upp <- c(0,Inf)
    
    set.seed(238374)
    sim <- data.frame(sim)
    
    # Split the tail of `par` into the a- and c- mean-shift coefficients
    N_a <- ncol(model.matrix(f_a, sim))
    a_coefs <- par[7:(7 + N_a - 1)]
    c_coefs <- par[(7+N_a):length(par)]
    
    a_mean_shifted <- a_mean + apply_meanshift(f_a, data, a_coefs)
    c_mean_shifted <- c_mean + apply_meanshift(f_c, data, c_coefs)
    
    # Remember the incoming row order so it can be restored after the join
    sim <- cbind(sim, a_mean_shifted, c_mean_shifted) %>%
      mutate(original_order = row_number())
    
    # One row of means per observation, plus an id for each distinct mean pair.
    # cur_group_id() numbers groups in sorted key order.
    mu <- as.data.frame(cbind(a_mean_shifted, c_mean_shifted)) %>%
      group_by(V1, V2) %>%
      mutate(group = cur_group_id()) %>%
      ungroup()
    colnames(mu) <- c("a_mean_shifted", "c_mean_shifted", "group")
    mu_leng <- nrow(mu)
    
    # Same group ids on `sim` (same keys, hence same numbering). Needed for the
    # join below regardless of `unique_flag`.
    sim <- sim %>% 
      group_by(a_mean_shifted, c_mean_shifted) %>%
      mutate(group = cur_group_id()) %>%
      ungroup()
    
    # Mean pairs for which the truncation pre-check is run
    if (unique_flag == 1) {
      if (nrow(mu) == nrow(data)) {
        print("right length")
      }
      print(nrow(mu))
      print(nrow(data))
      sample <- unique(mu) %>%
        dplyr::select(-group)
    } else {
      sample <- mu %>%
        dplyr::select(-group)
    }
    
    # Draw from the untruncated distribution first, check truncation:
    # share of 1000 draws per mean pair whose a-component violates a <= 0
    rejection_rates <- t(apply(
      sample,
      1,
      function(x){
        draws <- mvrnorm(n=1000, mu=x, Sigma=sigma)
        reject_rate <- mean(draws[,1] >= 0)
        return(reject_rate)
      }
    ))
    
    # Check the rejection rates -- if any are excess of 95%, reject.
    all_rejection_rates <- sum(rejection_rates >= 0.95)
    
    ### move seed to right before draws
    set.seed(3491)
    if (all_rejection_rates > 0) {
      # The parameters are garbage, set the draws to the bounds.
      # (This branch keeps the helper columns a_mean_shifted, c_mean_shifted,
      # original_order and group in the output.)
      bvn <- matrix(upp, nrow=nrow(mu), ncol=2, byrow=TRUE)
      colnames(bvn) <- c("a_i", "c_i")
      sim <- data.frame(cbind(sim, bvn))
    } else {
      # all good: for each distinct mean pair, draw as many truncated bivariate
      # normal values as there are observations in that group
      n_groups <- max(mu$group)
      draws_by_group <- vector("list", n_groups)
      for (i in seq_len(n_groups)) {
        mu_temp <- mu %>%
          dplyr::filter(group == i) %>%
          dplyr::select(-group)
        # Use this group's own mean. (Indexing unique(mu) by row position is
        # wrong: unique() keeps first-appearance order while group ids follow
        # sorted key order, so row i need not belong to group i.)
        mean_dist_vec <- unname(unlist(mu_temp[1, ]))
        bvn_temp <- rtmvnorm(n=nrow(mu_temp), mean = mean_dist_vec, sigma = sigma, lower = low, upper = upp, algorithm="rejection")
        draws_by_group[[i]] <- as.data.frame(bvn_temp) %>%
          mutate(group = i) %>%
          mutate(n = row_number())
      }
      bvn <- do.call(rbind, draws_by_group)
      
      # Attach the k-th draw of a group to the k-th observation of that group,
      # then restore the original row order and drop the helper columns
      sim <- sim %>%
        group_by(group) %>%
        mutate(n = row_number()) %>%
        ungroup() %>% 
        dplyr::inner_join(bvn, by = c("group", "n")) %>%
        rename(a_i = V1) %>%
        rename(c_i = V2) %>%
        arrange(original_order) %>%
        dplyr::select(-a_mean_shifted, -c_mean_shifted, -group, -n, -original_order)
    }
    if (mu_leng != nrow(sim)) {
      print("wrong length")
    }
    
    # keep variables necessary for simulating decisions
    out <- cbind(sim)
    
    return(out)},
    error = function(x){
      print(x)
      return(NA)
    }
  )
}

# Weighted sum of `var` divided by a total weight.
#
#   var        numeric vector of values.
#   wgt        numeric vector of weights, same length as `var`.
#   total_wgt  denominator. Defaults to the variable `sum_wgts` looked up in
#              the enclosing environment, which is what this function has
#              always used; pass it explicitly to avoid depending on a global.
#
# Note that the denominator is NOT sum(wgt): when `var`/`wgt` cover only a
# split of the sample, the result is that split's contribution relative to
# `total_wgt`.
mean_split <- function(var, wgt, total_wgt = sum_wgts) {
  sum(var * wgt) / total_wgt
}

# Function that calculates shares of observations with each outcome (for table),
# using mean_split() so that every figure is normalised by the total weight
# used inside mean_split() (`sum_wgts` by default) rather than by the sum of
# the weights in `data`.
#
#   data  output of outcomes(): needs sim_nipt, sim_invasive, sim_live,
#         sim_terminate, sim_positive, a_i, c_i, oop_i, oop2_i, sim_util and
#         sim_wgt.
#
# `g_cost_inv` and `g_cost_nipt` are resolved through dplyr data masking
# (column of `data` if present, otherwise the calling environment).
#
# Returns an unnamed numeric vector of length 18, rounded to 2 digits:
#    1 any test            2 NIPT only          3 invasive only
#    4 both tests          5 live birth         6 live birth w/ CA
#    7 live w/ CA, a >= c  8 live w/ CA, c > a  9 live birth w/o CA
#   10 terminated         11 terminated w/ CA  12 terminated w/o CA
#   13 bad outcome        14 testing cost      15 NIPT cost
#   16 invasive cost      17 surplus           18 share paying OOP for NIPT
# Entries 1-13 are in percent; 14-18 are not multiplied by 100.
outcomes_shares_split <- function(data){
  data <- data %>% 
    mutate(
      # testing patterns
      sim_subt = ifelse(sim_nipt == 1 | sim_invasive == 1, 1, 0),
      sim_nipt_only = ifelse(sim_nipt == 1 & sim_invasive == 0, 1, 0),
      sim_inv_only = ifelse(sim_nipt == 0 & sim_invasive == 1, 1, 0),
      sim_both_tests = ifelse(sim_nipt == 1 & sim_invasive == 1, 1, 0),
      # live birth with / without chrom abn
      outcome1 = ifelse((sim_live==1 & sim_positive==1), 1, 0),
      outcome2 = ifelse((sim_live==1 & sim_positive==0), 1, 0),
      # terminated pregnancy with / without chrom abn
      outcome3 = ifelse((sim_terminate==1 & sim_positive==1), 1, 0),
      outcome4 = ifelse((sim_terminate==1 & sim_positive==0), 1, 0),
      # live birth w CA, split by a >= c vs c > a
      outcome1a = ifelse((sim_live==1 & sim_positive==1 & a_i >= c_i), 1, 0),
      outcome1b = ifelse((sim_live==1 & sim_positive==1 & a_i < c_i), 1, 0),
      # terminated w CA & c > a
      outcome3b = ifelse((sim_terminate==1 & sim_positive==1 & a_i < c_i), 1, 0),
      # bad outcome (terminated preg w/o CA + terminated preg w CA & c > a + live birth w CA & a >= c)
      bad_outcome = ifelse((outcome4 == 1 | outcome3b == 1 | outcome1a == 1), 1, 0),
      # subsidy per test = cost minus out-of-pocket price; an infinite price
      # carries no subsidy
      g_subsidy_inv = ifelse(oop2_i == Inf, 0, g_cost_inv - oop2_i),
      g_subsidy_nipt = ifelse(oop_i == Inf, 0, g_cost_nipt - oop_i),
      cost_nipt_i = sim_nipt * (g_subsidy_nipt),
      cost_inv_i = sim_invasive * (g_subsidy_inv),
      cost_testing_i = cost_nipt_i + cost_inv_i,
      pays_oop_nipt = ifelse(oop_i != 0 & oop_i != Inf & sim_nipt == 1, 1 ,0)
    )
  
  # level and percentage versions of the split-normalised weighted mean
  lvl <- function(x) mean_split(x, data$sim_wgt)
  pct <- function(x) mean_split(x, data$sim_wgt) * 100
  
  shares <- c(
    pct(data$sim_subt),
    pct(data$sim_nipt_only),
    pct(data$sim_inv_only),
    pct(data$sim_both_tests),
    pct(data$sim_live),
    pct(data$outcome1),
    pct(data$outcome1a),
    pct(data$outcome1b),
    pct(data$outcome2),
    pct(data$sim_terminate),
    pct(data$outcome3),
    pct(data$outcome4),
    pct(data$bad_outcome),
    lvl(data$cost_testing_i),
    lvl(data$cost_nipt_i),
    lvl(data$cost_inv_i),
    lvl(data$sim_util),
    lvl(data$pays_oop_nipt)
  )
  
  return(round(shares, 2))
  
}

# Weighted shares of observations with each testing pattern / outcome, plus
# average testing costs and surplus (for the results table). Weights are
# data$sim_wgt and every figure is a weighted.mean() over the rows of `data`.
#
#   data  output of outcomes(): needs sim_nipt, sim_invasive, sim_live,
#         sim_terminate, sim_positive, a_i, c_i, oop_i, oop2_i, sim_util and
#         sim_wgt.
#
# `g_cost_inv` and `g_cost_nipt` are resolved through dplyr data masking
# (column of `data` if present, otherwise the calling environment).
#
# Returns an unnamed numeric vector of length 18, rounded to 2 digits:
#    1 any test            2 NIPT only          3 invasive only
#    4 both tests          5 live birth         6 live birth w/ CA
#    7 live w/ CA, a >= c  8 live w/ CA, c > a  9 live birth w/o CA
#   10 terminated         11 terminated w/ CA  12 terminated w/o CA
#   13 bad outcome        14 testing cost      15 NIPT cost
#   16 invasive cost      17 surplus           18 share paying OOP for NIPT
# Entries 1-13 are in percent; 14-18 are not multiplied by 100.
outcomes_shares <- function(data){
  data <- data %>% 
    mutate(
      # testing patterns
      sim_subt = ifelse(sim_nipt == 1 | sim_invasive == 1, 1, 0),
      sim_nipt_only = ifelse(sim_nipt == 1 & sim_invasive == 0, 1, 0),
      sim_inv_only = ifelse(sim_nipt == 0 & sim_invasive == 1, 1, 0),
      sim_both_tests = ifelse(sim_nipt == 1 & sim_invasive == 1, 1, 0),
      # live birth with / without chrom abn
      outcome1 = ifelse((sim_live==1 & sim_positive==1), 1, 0),
      outcome2 = ifelse((sim_live==1 & sim_positive==0), 1, 0),
      # terminated pregnancy with / without chrom abn
      outcome3 = ifelse((sim_terminate==1 & sim_positive==1), 1, 0),
      outcome4 = ifelse((sim_terminate==1 & sim_positive==0), 1, 0),
      # live birth w CA, split by a >= c vs c > a
      outcome1a = ifelse((sim_live==1 & sim_positive==1 & a_i >= c_i), 1, 0),
      outcome1b = ifelse((sim_live==1 & sim_positive==1 & a_i < c_i), 1, 0),
      # terminated w CA & c > a
      outcome3b = ifelse((sim_terminate==1 & sim_positive==1 & a_i < c_i), 1, 0),
      # bad outcome (terminated preg w/o CA + terminated preg w CA & c > a + live birth w CA & a >= c)
      bad_outcome = ifelse((outcome4 == 1 | outcome3b == 1 | outcome1a == 1), 1, 0),
      # subsidy per test = cost minus out-of-pocket price; an infinite price
      # carries no subsidy
      g_subsidy_inv = ifelse(oop2_i == Inf, 0, g_cost_inv - oop2_i),
      g_subsidy_nipt = ifelse(oop_i == Inf, 0, g_cost_nipt - oop_i),
      cost_nipt_i = sim_nipt * (g_subsidy_nipt),
      cost_inv_i = sim_invasive * (g_subsidy_inv),
      cost_testing_i = cost_nipt_i + cost_inv_i,
      pays_oop_nipt = ifelse(oop_i != 0 & oop_i != Inf & sim_nipt == 1, 1 ,0)
    )
  
  # level and percentage versions of the weighted mean
  lvl <- function(x) weighted.mean(x, data$sim_wgt)
  pct <- function(x) weighted.mean(x, data$sim_wgt) * 100
  
  shares <- c(
    pct(data$sim_subt),
    pct(data$sim_nipt_only),
    pct(data$sim_inv_only),
    pct(data$sim_both_tests),
    pct(data$sim_live),
    pct(data$outcome1),
    pct(data$outcome1a),
    pct(data$outcome1b),
    pct(data$outcome2),
    pct(data$sim_terminate),
    pct(data$outcome3),
    pct(data$outcome4),
    pct(data$bad_outcome),
    lvl(data$cost_testing_i),
    lvl(data$cost_nipt_i),
    lvl(data$cost_inv_i),
    lvl(data$sim_util),
    lvl(data$pays_oop_nipt)
  )
  
  return(round(shares, 2))
  
}